##### libraries #####
import os
import sys
import pandas as pd
import yaml
from snakemake.utils import min_version

# Refuse to run on Snakemake releases older than 6.0.3.
min_version("6.0.3")

# Absolute, symlink-resolved directory containing this Snakefile.
SDIR = os.path.realpath(os.path.dirname(srcdir("Snakefile")))
# Prepended to every shell command: exit on the first failing command (-e)
# and propagate failures from any stage of a pipeline (-o pipefail).
shell.prefix("set -eo pipefail;")

##### setup report #####
# Global workflow description (reStructuredText) for the Snakemake report.
report: os.path.join("report", "workflow.rst")

##### load config and sample annotation sheets #####
# Populates the global `config` dictionary used throughout this file.
configfile: os.path.join("config", "config.yaml")

# Sample annotation table read from the CSV named in the config,
# indexed by its 'sample_name' column.
annot = pd.read_csv(config['sample_annotation'], index_col='sample_name')

# Base directory for the per-split results requested in `rule all`.
result_path = os.path.join(config["result_path"],'scrnaseq_processing_seurat')

# gene list dictionary: union of the module and visualization gene lists.
# On duplicate keys the entry from "vis_gene_lists" wins (right operand of `|`).
# NOTE: the dict `|` operator requires Python >= 3.9 and both config entries to be mappings.
gene_list_dict = config["module_gene_lists"] | config["vis_gene_lists"]


# Names of the data subsets to process: always the full 'merged' data set,
# plus one "<key>_<value>" entry for every value listed under each key of
# config["split_by"] (skipped entirely when that entry is None).
data_splits = ['merged']
if config["split_by"] is not None:
    data_splits += [
        "{}_{}".format(split_key, split_value)
        for split_key in config["split_by"]
        for split_value in config["split_by"][split_key]
    ]

# Processing stages in the order they are listed here; config["stop_after"]
# names the last stage to run and must be one of these values (otherwise
# list.index raises ValueError below).
all_steps = ['RAW','FILTERED','NORMALIZED','CORRECTED']

# Metadata plots are requested for every stage up to AND including the final
# one. Previously the "CORRECTED" case sliced without the +1 and therefore
# dropped the CORRECTED metadata plots, unlike every other stop_after value.
metadata_plot_steps = all_steps[:all_steps.index(config["stop_after"])+1]

# Stages for which gene-level plots apply: none before normalization.
if config["stop_after"]=="CORRECTED":
    plot_steps = ["NORMALIZED","CORRECTED"]
elif config["stop_after"]=="NORMALIZED":
    plot_steps = ["NORMALIZED"]
else:
    plot_steps = []

# Default target rule: has no outputs of its own and only lists, as inputs,
# every file the workflow is expected to produce for the current config.
rule all:
    input:
        # object of the final processing step, one per data split
        final_objects = expand(
            os.path.join(result_path, '{split}', '{step}_object.rds'),
            split=data_splits,
            step=config["stop_after"],
        ),
        # count tables, only for the steps listed in config["save_counts"]
        counts = expand(
            os.path.join(result_path, '{split}', '{step}_RNA.csv'),
            split=data_splits,
            step=config["save_counts"],
        ) if len(config["save_counts"]) > 0 else [],
        # metadata overview plots for each step in metadata_plot_steps
        metadata_plots = expand(
            os.path.join(result_path, '{split}', 'plots', '{step}_metadata_{datatype}.png'),
            split=data_splits,
            step=metadata_plot_steps,
            datatype=['numerical', 'categorical', 'types'],
        ),
        # gene-list plots of the NORMALIZED data (requested once normalization runs)
        normalized_plots = expand(
            os.path.join(result_path, '{split}', 'plots', 'NORMALIZED_{plot_type}_{category}_{gene_list}.png'),
            split=data_splits,
            plot_type=['ridge_plot', 'violin_plot', 'dot_plot', 'heatmap'],
            category=config["vis_categories"],
            gene_list=list(config["vis_gene_lists"]),
        ) if config["stop_after"] in ("CORRECTED", "NORMALIZED") else [],
        # gene-list plots of the CORRECTED data (no dot_plot is requested here)
        corrected_plots = expand(
            os.path.join(result_path, '{split}', 'plots', 'CORRECTED_{plot_type}_{category}_{gene_list}.png'),
            split=data_splits,
            plot_type=['ridge_plot', 'violin_plot', 'heatmap'],
            category=config["vis_categories"],
            gene_list=list(config["vis_gene_lists"]),
        ) if config["stop_after"] == "CORRECTED" else [],
        # exported environment specifications, gene lists, config and annotation copies
        envs = expand(
            os.path.join(config["result_path"], 'envs', 'scrnaseq_processing_seurat', '{env}.yaml'),
            env=['seurat', 'inspectdf'],
        ),
        gene_lists = expand(
            os.path.join(config["result_path"], 'configs', 'scrnaseq_processing_seurat', '{gene_list}.txt'),
            gene_list=list(gene_list_dict),
        ),
        configs = os.path.join(
            config["result_path"], 'configs', 'scrnaseq_processing_seurat',
            '{}_config.yaml'.format(config["project_name"]),
        ),
        annotations = os.path.join(
            config["result_path"], 'configs', 'scrnaseq_processing_seurat',
            '{}_annot.csv'.format(config["project_name"]),
        ),
    resources:
        mem_mb=config.get("mem", "16000"),
    threads: config.get("threads", 1)
    log:
        os.path.join("logs", "rules", "all.log"),
    params:
        partition=config.get("partition"),

        
##### load rules #####
# Rule files are included in this order, after the globals above are defined;
# included files are evaluated in the same namespace as this Snakefile.
include: os.path.join("rules", "common.smk")
include: os.path.join("rules", "process.smk")
include: os.path.join("rules", "normalize_correct_score.smk")
include: os.path.join("rules", "visualize.smk")
include: os.path.join("rules", "envs_export.smk")
